{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import argparse\n",
    "\n",
    "# Restrict CUDA to device 0 (set here, before the later cells import the model code).\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
    "# Instrument universe name; used in the output directory names and passed to get_data_by_year.\n",
    "instruments = 'csi300'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from collections import Counter\n",
    "\n",
    "import torch\n",
    "\n",
    "from alphagen.data.expression import *\n",
    "from alphagen.models.alpha_pool import AlphaPool\n",
    "from alphagen.utils.correlation import batch_pearsonr, batch_spearmanr\n",
    "from alphagen_generic.features import *\n",
    "from gan.utils.data import get_data_by_year\n",
    "\n",
    "\n",
    "def pred_pool(capacity, data):\n",
    "    \"\"\"Fit a weighted ensemble of the top cached expressions and evaluate it on ``data``.\n",
    "\n",
    "    Relies on two names that are not parameters and must already be bound at\n",
    "    module level when this is called:\n",
    "\n",
    "    * ``cache`` -- handed to ``Counter``; its keys are expression source strings\n",
    "      that get ``eval``'d (presumably a dict of expression string -> score;\n",
    "      TODO confirm against the JSON files).\n",
    "    * ``target`` -- the target expression given to the pool and the calculator\n",
    "      (presumably provided by one of the star imports; TODO confirm).\n",
    "\n",
    "    Args:\n",
    "        capacity: Pool capacity, and the number of top-ranked ``cache`` entries loaded.\n",
    "        data: Stock data used both to optimize the pool weights and to compute\n",
    "            the ensemble output.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``QLibStockDataCalculator.make_ensemble_alpha`` returns for the\n",
    "        pool's expressions and weights on ``data``.\n",
    "    \"\"\"\n",
    "    from alphagen_qlib.calculator import QLibStockDataCalculator\n",
    "\n",
    "    pool = AlphaPool(\n",
    "        capacity=capacity,\n",
    "        stock_data=data,\n",
    "        target=target,\n",
    "        ic_lower_bound=None,\n",
    "    )\n",
    "\n",
    "    # NOTE(security): eval() executes arbitrary code -- only load cache files you trust.\n",
    "    # Kept as a plain loop so eval resolves names from this function's scope and\n",
    "    # the module globals, where the feature/operator names live.\n",
    "    candidate_exprs = []\n",
    "    for expr_source, _score in Counter(cache).most_common(capacity):\n",
    "        candidate_exprs.append(eval(expr_source))\n",
    "\n",
    "    pool.force_load_exprs(candidate_exprs)\n",
    "    pool._optimize(alpha=5e-3, lr=5e-4, n_iter=2000)\n",
    "\n",
    "    # Only the first `pool.size` slots of the pool arrays are read.\n",
    "    kept_exprs = pool.exprs[:pool.size]\n",
    "    kept_weights = pool.weights[:pool.size]\n",
    "    calculator = QLibStockDataCalculator(data, target)\n",
    "    return calculator.make_ensemble_alpha(kept_exprs, kept_weights)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference: fit top-N expression ensembles and save their predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every (seed, train_end) run: load the data split and the cached expressions once,\n",
    "# then fit one ensemble per pool size and save its trailing predictions to disk.\n",
    "for seed in range(5):\n",
    "    for train_end in range(2016,2021):\n",
    "        save_dir = f'out_gp/{instruments}_{train_end}_day_{seed}'\n",
    "\n",
    "        # The data split does not depend on the pool size, so load it outside the `num` loop.\n",
    "        returned = get_data_by_year(\n",
    "            train_start = 2010,train_end=train_end,valid_year=train_end+1,test_year =train_end+2,\n",
    "            instruments=instruments, target=target,freq='day',\n",
    "        )\n",
    "        data_all,data,data_valid,data_valid_withhead,data_test,data_test_withhead,name = returned\n",
    "\n",
    "        # pred_pool reads the module-level `cache`, so bind it before the calls below.\n",
    "        # The context manager closes the file handle deterministically.\n",
    "        with open(f'{save_dir}/40.json') as cache_file:\n",
    "            cache = json.load(cache_file)['cache']\n",
    "\n",
    "        for num in [1,10,20,50]:\n",
    "            print(save_dir)\n",
    "\n",
    "            # NOTE(review): the ensemble is computed on `data` and only its last\n",
    "            # `data_test.n_days` rows are kept, while the results cell compares against\n",
    "            # target.evaluate(data_all)[-data_test.n_days:]. Confirm that the trailing\n",
    "            # days of `data` really cover the test period (otherwise pass `data_all`).\n",
    "            pred = pred_pool(num,data)\n",
    "            pred = pred[-data_test.n_days:]\n",
    "            # Saved detached and on CPU so the file can be loaded without a GPU.\n",
    "            torch.save(pred.detach().cpu(),f\"{save_dir}/pred_{num}.pt\")\n",
    "            \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the saved predictions and summarize IC / Rank IC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# One row per (num, seed): IC / Rank IC statistics pooled over all test years.\n",
    "result = []\n",
    "# Test-period target values keyed by train_end. They do not depend on seed or num,\n",
    "# so each year's data is loaded once and reused.\n",
    "tgt_by_train_end = {}\n",
    "for num in [1]:\n",
    "    for seed in range(5):\n",
    "\n",
    "        cur_seed_ic = []\n",
    "        cur_seed_ric = []\n",
    "        for train_end in range(2016,2021):\n",
    "            # Directory the inference cell wrote pred_{num}.pt into.\n",
    "            save_dir = f'out_gp/{instruments}_{train_end}_day_{seed}'\n",
    "\n",
    "            if train_end not in tgt_by_train_end:\n",
    "                returned = get_data_by_year(\n",
    "                    train_start = 2010,train_end=train_end,valid_year=train_end+1,test_year =train_end+2,\n",
    "                    instruments=instruments, target=target,freq='day',\n",
    "                )\n",
    "                data_all,data,data_valid,data_valid_withhead,data_test,data_test_withhead,name = returned\n",
    "                # Evaluate the target on the full data set and keep the trailing test days.\n",
    "                tgt_by_train_end[train_end] = target.evaluate(data_all)[-data_test.n_days:,:]\n",
    "            tgt = tgt_by_train_end[train_end]\n",
    "\n",
    "            # Predictions were saved on CPU; move them to wherever the targets live\n",
    "            # instead of assuming 'cuda:0'.\n",
    "            pred = torch.load(f\"{save_dir}/pred_{num}.pt\").to(tgt.device)\n",
    "\n",
    "            # NaN correlations are counted as 0.\n",
    "            ic_s = torch.nan_to_num(batch_pearsonr(pred,tgt),nan=0)\n",
    "            rank_ic_s = torch.nan_to_num(batch_spearmanr(pred,tgt),nan=0)\n",
    "\n",
    "            cur_seed_ic.append(ic_s)\n",
    "            cur_seed_ric.append(rank_ic_s)\n",
    "\n",
    "        # Pool the correlation values of every test year for this seed.\n",
    "        ic = torch.cat(cur_seed_ic)\n",
    "        rank_ic = torch.cat(cur_seed_ric)\n",
    "\n",
    "        ic_mean = ic.mean().item()\n",
    "        rank_ic_mean = rank_ic.mean().item()\n",
    "        ic_std = ic.std().item()\n",
    "        rank_ic_std = rank_ic.std().item()\n",
    "        seed_summary = dict(\n",
    "            seed = seed,\n",
    "            num = num,\n",
    "            ic = ic_mean,\n",
    "            ric = rank_ic_mean,\n",
    "            icir = ic_mean/ic_std,\n",
    "            ricir = rank_ic_mean/rank_ic_std,\n",
    "        )\n",
    "        result.append(seed_summary)\n",
    "\n",
    "# Mean and std across seeds for each pool size.\n",
    "per_seed = pd.DataFrame(result).groupby(['num','seed']).mean()\n",
    "print(per_seed.groupby('num').agg(['mean','std']))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
